{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Theano 实例：Softmax 回归"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST 数据集的下载和导入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[MNIST 数据集](http://yann.lecun.com/exdb/mnist/) 是一个手写数字组成的数据集，现在被当作一个机器学习算法评测的基准数据集。\n",
    "\n",
    "这是一个下载并解压数据的脚本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting download_mnist.py\n"
     ]
    }
   ],
   "source": [
    "%%file download_mnist.py\n",
    "import os\n",
    "import os.path\n",
    "import urllib\n",
    "import gzip\n",
    "import shutil\n",
    "\n",
    "# Create the target directory on first run.\n",
    "if not os.path.exists('mnist'):\n",
    "    os.mkdir('mnist')\n",
    "\n",
    "def download_and_gzip(name):\n",
    "    \"\"\"Download ``name + '.gz'`` (if missing) and decompress it to ``name`` (if missing).\n",
    "\n",
    "    `name` is a path relative to http://yann.lecun.com/exdb/ and is also used as\n",
    "    the local path. Each step writes to a temporary ``.part`` file and renames it\n",
    "    only on success, so an interrupted run cannot leave a truncated file that the\n",
    "    existence checks would later mistake for a complete one.\n",
    "    \"\"\"\n",
    "    gz_name = name + '.gz'\n",
    "    if not os.path.exists(gz_name):\n",
    "        urllib.urlretrieve('http://yann.lecun.com/exdb/' + name + '.gz', gz_name + '.part')\n",
    "        os.rename(gz_name + '.part', gz_name)\n",
    "    if not os.path.exists(name):\n",
    "        with gzip.open(gz_name, 'rb') as f_in, open(name + '.part', 'wb') as f_out:\n",
    "            shutil.copyfileobj(f_in, f_out)\n",
    "        os.rename(name + '.part', name)\n",
    "            \n",
    "# Fetch and unpack the four MNIST files: train/test images and labels.\n",
    "for mnist_file in (\n",
    "    'mnist/train-images-idx3-ubyte',\n",
    "    'mnist/train-labels-idx1-ubyte',\n",
    "    'mnist/t10k-images-idx3-ubyte',\n",
    "    'mnist/t10k-labels-idx1-ubyte',\n",
    "):\n",
    "    download_and_gzip(mnist_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以运行这个脚本来下载和解压数据："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%run download_mnist.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用如下的脚本来导入 MNIST 数据，源码地址：\n",
    "\n",
    "https://github.com/Newmu/Theano-Tutorials/blob/master/load.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting load.py\n"
     ]
    }
   ],
   "source": [
    "%%file load.py\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# Directory that contains the `mnist/` folder with the unpacked data files.\n",
    "datasets_dir = './'\n",
    "\n",
    "def one_hot(x,n):\n",
    "\t\"\"\"Return a (len(x), n) float array with a 1 in column x[i] of row i.\n",
    "\n",
    "\t`x` is a sequence or array of integer class labels (used as column\n",
    "\tindices); it is flattened first. `n` is the number of columns.\n",
    "\t\"\"\"\n",
    "\t# np.asarray handles lists, tuples and arrays alike, not only `list`.\n",
    "\tlabels = np.asarray(x).flatten()\n",
    "\to_h = np.zeros((len(labels),n))\n",
    "\to_h[np.arange(len(labels)),labels] = 1\n",
    "\treturn o_h\n",
    "\n",
    "def _read_idx(filename, header_bytes):\n",
    "\t\"\"\"Load one raw file from the mnist folder as uint8 and drop its header bytes.\"\"\"\n",
    "\tpath = os.path.join(datasets_dir, 'mnist/', filename)\n",
    "\t# Given a path, np.fromfile opens the file in binary mode and closes it again,\n",
    "\t# so no handle is leaked and the bytes never pass through text mode.\n",
    "\treturn np.fromfile(file=path, dtype=np.uint8)[header_bytes:]\n",
    "\n",
    "def mnist(ntrain=60000,ntest=10000,onehot=True):\n",
    "\t\"\"\"Load MNIST from the unpacked files under `datasets_dir`/mnist/.\n",
    "\n",
    "\tReturns (trX, teX, trY, teY). Images are float arrays with 28*28 columns,\n",
    "\tdivided by 255. Labels are one-hot encoded with 10 columns when `onehot`\n",
    "\tis true, otherwise the raw uint8 values. Only the first `ntrain` training\n",
    "\tand `ntest` test samples are kept.\n",
    "\t\"\"\"\n",
    "\t# The first 16 bytes of an image file and 8 bytes of a label file are skipped.\n",
    "\ttrX = _read_idx('train-images-idx3-ubyte', 16).reshape((60000,28*28)).astype(float)\n",
    "\ttrY = _read_idx('train-labels-idx1-ubyte', 8).reshape((60000))\n",
    "\tteX = _read_idx('t10k-images-idx3-ubyte', 16).reshape((10000,28*28)).astype(float)\n",
    "\tteY = _read_idx('t10k-labels-idx1-ubyte', 8).reshape((10000))\n",
    "\n",
    "\t# Slice before scaling so only the requested samples are divided.\n",
    "\ttrX = trX[:ntrain]/255.\n",
    "\ttrY = trY[:ntrain]\n",
    "\tteX = teX[:ntest]/255.\n",
    "\tteY = teY[:ntest]\n",
    "\n",
    "\tif onehot:\n",
    "\t\ttrY = one_hot(trY, 10)\n",
    "\t\tteY = one_hot(teY, 10)\n",
    "\n",
    "\treturn trX,teX,trY,teY"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## softmax 回归"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Softmax` 回归相当于 `Logistic` 回归的一个一般化，`Logistic` 回归处理的是两类问题，`Softmax` 回归处理的是 `N` 类问题。\n",
    "\n",
    "`Logistic` 回归输出的是标签为 1 的概率（标签为 0 的概率也就知道了），对应地，对 N 类问题 `Softmax` 输出的是每个类对应的概率。\n",
    "\n",
    "具体的内容，可以参考 `UFLDL` 教程：\n",
    "\n",
    "http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 1: Tesla C2075 (CNMeM is disabled)\n"
     ]
    }
   ],
   "source": [
    "import theano\n",
    "from theano import tensor as T\n",
    "import numpy as np\n",
    "from load import mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们来看它具体的实现。\n",
    "\n",
    "下面两个函数中，`floatX` 将数据转换为 `theano.config.floatX` 指定的浮点类型（即 `GPU` 计算所使用的类型），`init_weights` 用较小的随机数初始化权重："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def floatX(X):\n",
    "    \"\"\"Convert `X` to a NumPy array whose dtype is `theano.config.floatX`.\"\"\"\n",
    "    target_dtype = theano.config.floatX\n",
    "    return np.asarray(X, dtype=target_dtype)\n",
    "\n",
    "def init_weights(shape):\n",
    "    \"\"\"Return a Theano shared variable of `shape` holding randn samples scaled by 0.01.\"\"\"\n",
    "    initial_values = np.random.randn(*shape) * 0.01\n",
    "    return theano.shared(floatX(initial_values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Softmax` 的模型在 `theano` 中已经实现好了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 4)\n",
      "[ 1.00000012  1.          1.        ]\n"
     ]
    }
   ],
   "source": [
    "# Symbolic graph: B is the softmax of the matrix A.\n",
    "A = T.matrix()\n",
    "B = T.nnet.softmax(A)\n",
    "test_softmax = theano.function([A], B)\n",
    "\n",
    "# Evaluate the compiled function on a random 3x4 input.\n",
    "a = floatX(np.random.rand(3, 4))\n",
    "b = test_softmax(a)\n",
    "\n",
    "print(b.shape)\n",
    "\n",
    "# row sums\n",
    "print(b.sum(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`softmax` 函数会按照行对矩阵进行 `Softmax` 归一化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所以我们的模型为："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model(X, w):\n",
    "    \"\"\"Softmax regression: apply softmax to the linear scores `X . w`.\"\"\"\n",
    "    logits = T.dot(X, w)\n",
    "    return T.nnet.softmax(logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入数据："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the train/test split with the default sizes; labels come back one-hot encoded.\n",
    "trX, teX, trY, teY = mnist(onehot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义变量，并初始化权重："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Symbolic float32 matrices: X receives a batch of inputs, Y the matching targets.\n",
    "X = T.fmatrix()\n",
    "Y = T.fmatrix()\n",
    "\n",
    "# Weight matrix with 784 rows (28*28 pixels per image) and 10 columns (one per class).\n",
    "w = init_weights((784, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义模型输出和预测："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# py_x: softmax output of the model for X; y_pred: index of the largest entry in each row.\n",
    "py_x = model(X, w)\n",
    "y_pred = T.argmax(py_x, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "损失函数为多类的交叉熵，这个在 `theano` 中也被定义好了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Step size of the plain gradient-descent update.\n",
    "learning_rate = 0.05\n",
    "\n",
    "# Mean categorical cross-entropy between the model output and the targets Y.\n",
    "cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))\n",
    "gradient = T.grad(cost=cost, wrt=w)\n",
    "# One [shared variable, new value] pair: move w against the gradient.\n",
    "update = [[w, w - gradient * learning_rate]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "编译 `train` 和 `predict` 函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# train: returns the cost for a batch and applies the weight update as a side effect.\n",
    "# allow_input_downcast=True lets higher-precision NumPy inputs be cast to the symbolic input dtype.\n",
    "train = theano.function(inputs=[X, Y], outputs=cost, updates=update, allow_input_downcast=True)\n",
    "# predict: returns the predicted class index for each row of X.\n",
    "predict = theano.function(inputs=[X], outputs=y_pred, allow_input_downcast=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "迭代 100 次，测试集正确率为 0.925："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "000 0.8862\n",
      "001 0.8985\n",
      "002 0.9042\n",
      "003 0.9084\n",
      "004 0.9104\n",
      "005 0.9121\n",
      "006 0.9121\n",
      "007 0.9142\n",
      "008 0.9158\n",
      "009 0.9163\n",
      "010 0.9162\n",
      "011 0.9166\n",
      "012 0.9171\n",
      "013 0.9176\n",
      "014 0.9182\n",
      "015 0.9182\n",
      "016 0.9184\n",
      "017 0.9188\n",
      "018 0.919\n",
      "019 0.919\n",
      "020 0.9194\n",
      "021 0.9201\n",
      "022 0.9204\n",
      "023 0.9203\n",
      "024 0.9205\n",
      "025 0.9207\n",
      "026 0.9207\n",
      "027 0.9209\n",
      "028 0.9214\n",
      "029 0.9213\n",
      "030 0.9212\n",
      "031 0.9211\n",
      "032 0.9217\n",
      "033 0.9217\n",
      "034 0.9217\n",
      "035 0.922\n",
      "036 0.9222\n",
      "037 0.922\n",
      "038 0.922\n",
      "039 0.9218\n",
      "040 0.9219\n",
      "041 0.9223\n",
      "042 0.9225\n",
      "043 0.9226\n",
      "044 0.9227\n",
      "045 0.9225\n",
      "046 0.9227\n",
      "047 0.9231\n",
      "048 0.9231\n",
      "049 0.9231\n",
      "050 0.9232\n",
      "051 0.9232\n",
      "052 0.9231\n",
      "053 0.9231\n",
      "054 0.9233\n",
      "055 0.9233\n",
      "056 0.9237\n",
      "057 0.9239\n",
      "058 0.9239\n",
      "059 0.9239\n",
      "060 0.924\n",
      "061 0.9242\n",
      "062 0.9242\n",
      "063 0.9243\n",
      "064 0.9243\n",
      "065 0.9244\n",
      "066 0.9244\n",
      "067 0.9244\n",
      "068 0.9245\n",
      "069 0.9244\n",
      "070 0.9244\n",
      "071 0.9245\n",
      "072 0.9244\n",
      "073 0.9243\n",
      "074 0.9243\n",
      "075 0.9244\n",
      "076 0.9243\n",
      "077 0.9242\n",
      "078 0.9244\n",
      "079 0.9244\n",
      "080 0.9243\n",
      "081 0.9242\n",
      "082 0.9239\n",
      "083 0.9241\n",
      "084 0.9242\n",
      "085 0.9243\n",
      "086 0.9244\n",
      "087 0.9243\n",
      "088 0.9243\n",
      "089 0.9244\n",
      "090 0.9246\n",
      "091 0.9246\n",
      "092 0.9246\n",
      "093 0.9247\n",
      "094 0.9246\n",
      "095 0.9246\n",
      "096 0.9246\n",
      "097 0.9246\n",
      "098 0.9246\n",
      "099 0.9248\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "\n",
    "for i in range(100):\n",
    "    # Step through every mini-batch, including the final partial one\n",
    "    # (slicing past the end of the arrays simply yields a shorter batch).\n",
    "    for start in range(0, len(trX), batch_size):\n",
    "        end = start + batch_size\n",
    "        # Use a separate name so the symbolic `cost` expression is not overwritten.\n",
    "        batch_cost = train(trX[start:end], trY[start:end])\n",
    "    # Fraction of test samples whose predicted class matches the one-hot label.\n",
    "    accuracy = np.mean(np.argmax(teY, axis=1) == predict(teX))\n",
    "    print(\"{0:03d}\".format(i) + \" \" + str(accuracy))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
